# hierarchical_lasso.py
"""
Hierarchical Lasso implementation for finding effective interaction pairs.

This module implements the Strong-hierarchy lasso for pairwise interactions
based on Bien et al., 2013. The implementation uses CVXPY for convex optimization.
"""

from typing import Optional, Dict, Any, List, Tuple, Union
import warnings

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

try:
    import cvxpy as cp
except ImportError:
    raise ImportError(
        "cvxpy is required for HierarchicalLasso. Install it with: pip install cvxpy"
    )

class HierarchicalLasso:
    """
    Strong-hierarchy lasso for pairwise interactions (Bien et al., 2013).

    Solves the convex relaxation with symmetry and hierarchy constraints to find
    effective interaction pairs using LASSO regularization.

    Objective (quadratic loss + l1 + tiny ridge):
      min_{beta+, beta-, Theta, beta0}
        (1/(2n)) * || y - (beta0 + X @ (beta+ - beta-) + 0.5 * q(X, Theta)) ||_2^2
        + lam * (sum(beta+) + sum(beta-) + 0.5 * ||Theta||_1)
        + eps * 0.5 * (||beta+||_2^2 + ||beta-||_2^2 + ||Theta||_F^2)

    Subject to:
        beta+ >= 0, beta- >= 0, Theta == Theta.T, diag(Theta) == 0,
        sum_k |Theta_{j,k}|  <= beta+_j + beta-_j  (for all j)

    where q(X, Theta)_n = trace(Theta @ (x_n x_n^T)) = sum_{i,k} Theta_{i,k} x_{n,i} x_{n,k}
    and y is the mean-centered training target.

    Parameters
    ----------
    lam : float, default=0.1
        L1 regularization strength (lambda). Must be non-negative.
    eps : float, default=1e-8
        Tiny ridge (elastic-net) term for uniqueness/stability. Must be non-negative.
    solver : str or None, default=None
        CVXPY solver name (e.g., "OSQP", "ECOS", "SCS"). If None, CVXPY picks automatically.
    solver_opts : dict or None, default=None
        Additional keyword options forwarded to ``cvxpy.Problem.solve``.
    fit_intercept : bool, default=True
        Whether to fit an additional intercept ``beta0`` on the centered target.
        The target mean is always removed before solving and added back when
        predicting, so with ``fit_intercept=False`` the model's intercept is the
        training-target mean.
    standardize : bool, default=True
        Whether to standardize features (mean 0, std 1) before solving. The target
        is centered regardless of this flag.
    random_state : int or None, default=None
        Accepted for API compatibility only. Fitting draws no random numbers, so the
        value has no effect (the global NumPy RNG is not reseeded).
    """

    def __init__(
        self,
        lam: float = 0.1,
        eps: float = 1e-8,
        solver: Optional[str] = None,
        solver_opts: Optional[Dict[str, Any]] = None,
        fit_intercept: bool = True,
        standardize: bool = True,
        random_state: Optional[int] = None,
    ) -> None:
        # ``not x >= 0`` also rejects NaN, which ``x < 0`` would let through.
        if not lam >= 0:
            raise ValueError(f"lam must be non-negative, got {lam}")
        if not eps >= 0:
            raise ValueError(f"eps must be non-negative, got {eps}")

        self.lam = float(lam)
        self.eps = float(eps)
        self.solver = solver
        self.solver_opts = solver_opts or {}
        self.fit_intercept = fit_intercept
        self.standardize = standardize
        self.random_state = random_state

        # Learned parameters (set during fit)
        self.beta0_: Optional[float] = None
        self.beta_plus_: Optional[np.ndarray] = None
        self.beta_minus_: Optional[np.ndarray] = None
        self.theta_: Optional[np.ndarray] = None  # full p x p matrix (symmetric, zero diagonal)
        self.intercept_: Optional[float] = None

        # Preprocessing attributes
        self.x_scaler_: Optional[StandardScaler] = None
        self.y_mean_: Optional[float] = None
        self._is_fitted: bool = False

    def _prepare(self, X: Union[np.ndarray, List], y: Union[np.ndarray, List]) -> Tuple[np.ndarray, np.ndarray, int, int]:
        """
        Validate the training data, fit the preprocessing, and return transformed data.

        Side effects: sets ``self.x_scaler_`` (or None when ``standardize`` is False)
        and ``self.y_mean_``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input features.
        y : array-like of shape (n_samples,)
            Target values.

        Returns
        -------
        Xs : np.ndarray
            Processed (optionally standardized) features.
        ys : np.ndarray
            Mean-centered target.
        n : int
            Number of samples.
        p : int
            Number of features.

        Raises
        ------
        ValueError
            If inputs are non-numeric, not 2-D / empty, have mismatched lengths,
            contain NaN/inf, or have fewer than 2 features.
        """
        try:
            X = np.asarray(X, dtype=float)
            y = np.asarray(y, dtype=float).reshape(-1)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Could not convert input to numeric arrays: {e}") from e

        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional (n_samples, n_features), got shape {X.shape}"
            )
        n, p = X.shape

        if n == 0:
            raise ValueError("X must contain at least one sample")
        if len(y) != n:
            raise ValueError(f"X and y have incompatible shapes: X has {n} samples, y has {len(y)}")
        # Non-finite values would otherwise surface as an opaque solver failure.
        if not (np.all(np.isfinite(X)) and np.all(np.isfinite(y))):
            raise ValueError("X and y must not contain NaN or infinite values")
        if n < p:
            warnings.warn(f"Number of samples ({n}) is less than number of features ({p}). "
                          "This may lead to overfitting.", UserWarning)
        if p < 2:
            raise ValueError(f"Need at least 2 features for interaction modeling, got {p}")

        if self.standardize:
            self.x_scaler_ = StandardScaler(with_mean=True, with_std=True)
            Xs = self.x_scaler_.fit_transform(X)
        else:
            self.x_scaler_ = None
            Xs = X.copy()

        # The target is always centered; its mean is added back in ``predict``.
        self.y_mean_ = float(np.mean(y))
        ys = y - self.y_mean_

        return Xs, ys, n, p

    def fit(self, X: Union[np.ndarray, List], y: Union[np.ndarray, List], verbose: bool = False) -> 'HierarchicalLasso':
        """
        Fit the hierarchical lasso model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training features.
        y : array-like of shape (n_samples,)
            Training target values.
        verbose : bool, default=False
            Whether to print solver output.

        Returns
        -------
        self : HierarchicalLasso
            Fitted estimator.

        Raises
        ------
        ValueError
            If the input data is invalid (see ``_prepare``).
        RuntimeError
            If the solver raises or the optimization fails to converge.
        """
        Xs, ys, n, p = self._prepare(X, y)
        # ``_prepare`` has just replaced the scaler and target mean, so any
        # previously stored coefficients no longer match them. Stay unfitted
        # until the new solve succeeds.
        self._is_fitted = False

        # Variables
        beta_plus = cp.Variable(p, nonneg=True)
        beta_minus = cp.Variable(p, nonneg=True)
        Theta = cp.Variable((p, p))
        beta0 = cp.Variable() if self.fit_intercept else 0.0  # constant when no intercept

        constraints = [
            Theta == Theta.T,
            cp.diag(Theta) == 0,
            # Strong hierarchy, one inequality per row j:
            #   ||Theta[j, :]||_1 <= beta_plus[j] + beta_minus[j]
            # (the diagonal entry contributes nothing since it is constrained to 0)
            cp.sum(cp.abs(Theta), axis=1) <= beta_plus + beta_minus,
        ]

        # yhat = beta0 + Xs @ (beta_plus - beta_minus) + 0.5 * q, where
        # q_n = x_n^T Theta x_n, computed as the row-sum of (Xs @ Theta) * Xs.
        M = Xs @ Theta                          # (n x p)
        q = cp.sum(cp.multiply(M, Xs), axis=1)  # (n,)
        yhat = beta0 + Xs @ (beta_plus - beta_minus) + 0.5 * q

        data_fit = (1.0 / (2 * n)) * cp.sum_squares(ys - yhat)
        l1_penalty = self.lam * (cp.sum(beta_plus) + cp.sum(beta_minus) + 0.5 * cp.norm1(Theta))
        ridge_penalty = 0.5 * self.eps * (
            cp.sum_squares(beta_plus) + cp.sum_squares(beta_minus) + cp.sum_squares(Theta)
        )
        problem = cp.Problem(cp.Minimize(data_fit + l1_penalty + ridge_penalty), constraints)

        try:
            if self.solver is not None:
                problem.solve(solver=self.solver, verbose=verbose, **self.solver_opts)
            else:
                problem.solve(verbose=verbose, **self.solver_opts)
        except Exception as e:
            raise RuntimeError(f"Solver failed with error: {e}") from e

        if problem.status not in ("optimal", "optimal_inaccurate"):
            raise RuntimeError(f"Optimization did not converge: status={problem.status}")
        if problem.status == "optimal_inaccurate":
            warnings.warn(
                "Solver returned an inaccurate solution (status=optimal_inaccurate); "
                "consider raising the iteration limit or tightening solver tolerances.",
                UserWarning,
            )

        # The solver satisfies the equality constraints only to within its
        # tolerance. Enforce exact symmetry and a zero diagonal; x^T Theta x is
        # invariant under symmetrisation, so predictions are unaffected.
        theta = np.asarray(Theta.value, dtype=float)
        theta = 0.5 * (theta + theta.T)
        np.fill_diagonal(theta, 0.0)

        self.beta_plus_ = np.asarray(beta_plus.value, dtype=float)
        self.beta_minus_ = np.asarray(beta_minus.value, dtype=float)
        self.theta_ = theta
        self.beta0_ = float(beta0.value) if self.fit_intercept else 0.0

        # Intercept in the original target scale
        self.intercept_ = self.y_mean_ + self.beta0_
        self._is_fitted = True

        return self

    def _check_fitted(self) -> None:
        """Raise ValueError if the model has not been (successfully) fitted."""
        if not self._is_fitted:
            raise ValueError("This HierarchicalLasso instance is not fitted yet. "
                             "Call 'fit' with appropriate arguments before using this estimator.")

    def coef_main_(self) -> np.ndarray:
        """
        Return main-effect coefficients (beta = beta+ - beta-) in the (possibly standardized) X space.

        Returns
        -------
        np.ndarray of shape (n_features,)
            Main effect coefficients.
        """
        self._check_fitted()
        return self.beta_plus_ - self.beta_minus_

    def coef_interactions_(self) -> np.ndarray:
        """
        Return the symmetric interaction matrix Theta (p x p) in the (possibly standardized) X space.

        Returns
        -------
        np.ndarray of shape (n_features, n_features)
            Symmetric interaction coefficient matrix with zero diagonal.
        """
        self._check_fitted()
        return self.theta_

    def selected_main_indices_(self, tol: float = 1e-8) -> np.ndarray:
        """
        Get indices of selected main effects.

        Parameters
        ----------
        tol : float, default=1e-8
            Tolerance threshold for considering a coefficient as non-zero.

        Returns
        -------
        np.ndarray
            Indices of features with |beta_j| > tol.
        """
        if tol < 0:
            raise ValueError(f"Tolerance must be non-negative, got {tol}")

        beta = self.coef_main_()
        return np.where(np.abs(beta) > tol)[0]

    def selected_interactions_(self, tol: float = 1e-8) -> List[Tuple[int, int, float]]:
        """
        Get selected interaction pairs.

        Parameters
        ----------
        tol : float, default=1e-8
            Tolerance threshold for considering an interaction as non-zero.

        Returns
        -------
        List[Tuple[int, int, float]]
            List of (i, j, theta_ij) tuples with i < j and |theta_ij| > tol,
            in row-major order.
        """
        if tol < 0:
            raise ValueError(f"Tolerance must be non-negative, got {tol}")

        Theta = self.coef_interactions_()
        p = Theta.shape[0]
        return [
            (i, j, Theta[i, j])
            for i in range(p)
            for j in range(i + 1, p)
            if abs(Theta[i, j]) > tol
        ]

    def _transform_X(self, X: Union[np.ndarray, List]) -> np.ndarray:
        """Convert ``X`` to a 2-D float array and apply the fitted scaler if applicable."""
        try:
            X = np.asarray(X, dtype=float)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Could not convert input to numeric array: {e}") from e

        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional (n_samples, n_features), got shape {X.shape}"
            )
        if self.standardize and self.x_scaler_ is not None:
            return self.x_scaler_.transform(X)
        return X

    def predict(self, X: Union[np.ndarray, List]) -> np.ndarray:
        """
        Predict target values in original scale.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input features.

        Returns
        -------
        np.ndarray of shape (n_samples,)
            Predicted target values.

        Raises
        ------
        ValueError
            If the model is unfitted, X is not 2-D, or the feature count differs
            from the training data.
        """
        self._check_fitted()

        Xs = self._transform_X(X)
        n_features = Xs.shape[1]
        if n_features != self.theta_.shape[0]:
            raise ValueError(f"X has {n_features} features, but model was fitted with "
                             f"{self.theta_.shape[0]} features")

        beta = self.coef_main_()
        q = np.sum((Xs @ self.theta_) * Xs, axis=1)  # q_n = x_n^T Theta x_n
        yhat_centered = self.beta0_ + Xs @ beta + 0.5 * q
        return self.y_mean_ + yhat_centered

    @staticmethod
    def _to_target(y: Union[np.ndarray, List]) -> np.ndarray:
        """Convert ``y`` to a flat float array, raising ValueError on failure."""
        try:
            return np.asarray(y, dtype=float).reshape(-1)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Could not convert y to numeric array: {e}") from e

    def _target_and_prediction(
        self, X: Union[np.ndarray, List], y: Union[np.ndarray, List]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return (y, predict(X)) after checking that their lengths agree."""
        y = self._to_target(y)
        y_pred = self.predict(X)
        if len(y) != len(y_pred):
            raise ValueError(f"y and predictions have different lengths: {len(y)} vs {len(y_pred)}")
        return y, y_pred

    def score(self, X: Union[np.ndarray, List], y: Union[np.ndarray, List]) -> float:
        """
        Return the coefficient of determination R² of the prediction.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test features.
        y : array-like of shape (n_samples,)
            True target values.

        Returns
        -------
        float
            R² score. When y is constant, returns 1.0 for a perfect fit and 0.0 otherwise.
        """
        y, y_pred = self._target_and_prediction(X, y)

        ss_res = np.sum((y - y_pred) ** 2)
        ss_tot = np.sum((y - np.mean(y)) ** 2)

        if ss_tot == 0:
            return 1.0 if ss_res == 0 else 0.0

        return 1.0 - (ss_res / ss_tot)

    def score_mse(self, X: Union[np.ndarray, List], y: Union[np.ndarray, List]) -> float:
        """
        Return the mean squared error of the prediction.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test features.
        y : array-like of shape (n_samples,)
            True target values.

        Returns
        -------
        float
            Mean squared error.
        """
        y, y_pred = self._target_and_prediction(X, y)
        return float(np.mean((y - y_pred) ** 2))

    @property
    def coef_(self) -> np.ndarray:
        """Main effect coefficients (sklearn-style property)."""
        return self.coef_main_()

    @property
    def intercept(self) -> float:
        """Intercept term in original scale (sklearn-style property)."""
        self._check_fitted()
        return self.intercept_

    @property
    def interaction_matrix(self) -> np.ndarray:
        """Interaction coefficient matrix (property alias)."""
        return self.coef_interactions_()

    def get_feature_names_out(self, input_features: Optional[List[str]] = None) -> List[str]:
        """
        Get output feature names for the selected main effects and interaction terms.

        Parameters
        ----------
        input_features : list of str or None
            Input feature names. If None, generic names ``x0, x1, ...`` are generated.

        Returns
        -------
        List[str]
            Selected main-effect names followed by ``"a * b"`` names of selected interactions.

        Raises
        ------
        ValueError
            If the model is unfitted or ``input_features`` has the wrong length.
        """
        self._check_fitted()

        n_features = self.theta_.shape[0]
        if input_features is None:
            input_features = [f"x{i}" for i in range(n_features)]
        elif len(input_features) != n_features:
            raise ValueError(f"input_features length ({len(input_features)}) does not match "
                             f"number of features ({n_features})")

        main_names = [input_features[idx] for idx in self.selected_main_indices_()]
        interaction_names = [
            f"{input_features[i]} * {input_features[j]}"
            for i, j, _ in self.selected_interactions_()
        ]
        return main_names + interaction_names

    def summary(self, feature_names: Optional[List[str]] = None, tol: float = 1e-8) -> str:
        """
        Generate a human-readable summary of the fitted model.

        Parameters
        ----------
        feature_names : list of str or None
            Feature names for display. If None, ``Feature_0, Feature_1, ...`` are used.
        tol : float, default=1e-8
            Tolerance for determining non-zero coefficients.

        Returns
        -------
        str
            Model summary string.

        Raises
        ------
        ValueError
            If the model is unfitted or ``feature_names`` has the wrong length.
        """
        self._check_fitted()

        n_features = self.theta_.shape[0]
        if feature_names is None:
            feature_names = [f"Feature_{i}" for i in range(n_features)]
        elif len(feature_names) != n_features:
            raise ValueError(f"feature_names length ({len(feature_names)}) does not match "
                             f"number of features ({n_features})")

        lines = [
            "Hierarchical Lasso Model Summary",
            "=" * 40,
            f"Regularization strength (lambda): {self.lam}",
            f"Ridge parameter (eps): {self.eps}",
            f"Standardized features: {self.standardize}",
            f"Fit intercept: {self.fit_intercept}",
            "",
        ]

        # Main effects
        main_indices = self.selected_main_indices_(tol=tol)
        main_coefs = self.coef_main_()
        lines.append(f"Selected Main Effects ({len(main_indices)}):")
        if len(main_indices) > 0:
            for idx in main_indices:
                lines.append(f"  {feature_names[idx]:20s}: {main_coefs[idx]:8.4f}")
        else:
            lines.append("  None")
        lines.append("")

        # Interactions
        interactions = self.selected_interactions_(tol=tol)
        lines.append(f"Selected Interactions ({len(interactions)}):")
        if interactions:
            for i, j, coef in interactions:
                pair_name = f"{feature_names[i]} * {feature_names[j]}"
                lines.append(f"  {pair_name:20s}: {coef:8.4f}")
        else:
            lines.append("  None")

        return "\n".join(lines)


def get_lasso_interactions(
    X: Union[np.ndarray, pd.DataFrame],
    y: Union[np.ndarray, List],
    lam: float = 0.1,
    feature_names: Optional[List[str]] = None,
    tol: float = 1e-8,
    **lasso_kwargs
) -> pd.DataFrame:
    """
    Get ranked interaction dataframe from hierarchical lasso method.

    This function fits a hierarchical lasso model and returns a DataFrame containing
    all feature pairs ranked by their interaction strength, compatible with the
    format from get_shap_mean_baseline and get_improved_interactions.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input features.
    y : array-like of shape (n_samples,)
        Target values.
    lam : float, default=0.1
        L1 regularization strength for the lasso.
    feature_names : list of str or None, default=None
        Feature names. If None, will use DataFrame columns or generate default names.
        Must have one entry per column of X.
    tol : float, default=1e-8
        Tolerance for considering interactions as non-zero: coefficients with
        absolute value <= tol (solver noise) are reported as exactly 0.0.
    **lasso_kwargs
        Additional arguments passed to HierarchicalLasso constructor.

    Returns
    -------
    pd.DataFrame
        DataFrame with one row per feature pair (i < j) and columns:
        - i: index of first feature
        - j: index of second feature
        - feature_i: name of first feature
        - feature_j: name of second feature
        - mean_interaction: interaction coefficient value
        - mean_abs_interaction: absolute value of interaction coefficient
        Sorted by mean_abs_interaction in descending order; ties keep (i, j) order.

    Raises
    ------
    ValueError
        If ``tol`` is negative or ``feature_names`` does not match the number of columns.

    Examples
    --------
    >>> X, y, _, _ = generate_synthetic_data(n_samples=100, n_features=5)
    >>> df = get_lasso_interactions(X, y, lam=0.1)
    >>> print(df.head())
    """
    if tol < 0:
        raise ValueError(f"Tolerance must be non-negative, got {tol}")

    # Handle feature names
    if feature_names is None:
        if hasattr(X, 'columns'):
            feature_names = list(X.columns)
        else:
            n_columns = X.shape[1] if hasattr(X, 'shape') else len(X[0])
            feature_names = [f"x_{i}" for i in range(n_columns)]
    else:
        feature_names = list(feature_names)

    # Validate names before the (potentially expensive) solve. A mismatch would
    # otherwise silently drop pairs or fail with an IndexError afterwards.
    if np.ndim(X) == 2 and len(feature_names) != np.shape(X)[1]:
        raise ValueError(f"feature_names length ({len(feature_names)}) does not match "
                         f"number of features ({np.shape(X)[1]})")

    # Default lasso parameters, overridable by the caller
    lasso_params = {
        'eps': 1e-8,
        'standardize': True,
        'fit_intercept': True,
        'random_state': None
    }
    lasso_params.update(lasso_kwargs)

    model = HierarchicalLasso(lam=lam, **lasso_params)
    model.fit(X, y)

    theta = model.coef_interactions_()
    n_features = theta.shape[0]

    # All pairs i < j in row-major order (same order as a nested i/j loop)
    i_idx, j_idx = np.triu_indices(n_features, k=1)
    coefs = np.array(theta[i_idx, j_idx], dtype=float)
    coefs[np.abs(coefs) <= tol] = 0.0  # treat solver noise as exact zeros

    results_df = pd.DataFrame({
        'i': i_idx,
        'j': j_idx,
        'feature_i': [feature_names[k] for k in i_idx],
        'feature_j': [feature_names[k] for k in j_idx],
        'mean_interaction': coefs,
        'mean_abs_interaction': np.abs(coefs),
    })
    # Stable sort keeps ties (e.g. the zeroed pairs) in deterministic (i, j) order.
    return results_df.sort_values(
        'mean_abs_interaction', ascending=False, kind='stable'
    ).reset_index(drop=True)


def get_lasso_interactions_filtered(
    X: Union[np.ndarray, pd.DataFrame],
    y: Union[np.ndarray, List],
    lam: float = 0.1,
    feature_names: Optional[List[str]] = None,
    tol: float = 1e-3,
    **lasso_kwargs
) -> pd.DataFrame:
    """
    Return only the interaction pairs whose strength reaches a threshold.

    Thin wrapper around ``get_lasso_interactions``. It fits the model and ranks
    all pairs, then keeps the rows with ``mean_abs_interaction >= tol``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input features.
    y : array-like of shape (n_samples,)
        Target values.
    lam : float, default=0.1
        L1 regularization strength for the lasso.
    feature_names : list of str or None, default=None
        Feature names. If None, will use DataFrame columns or generate default names.
    tol : float, default=1e-3
        Minimum absolute interaction coefficient to include in results.
    **lasso_kwargs
        Additional arguments passed to HierarchicalLasso constructor.

    Returns
    -------
    pd.DataFrame
        Same columns as ``get_lasso_interactions``, restricted to the kept rows,
        still sorted by strength and with a fresh 0..k-1 index.
    """
    ranked = get_lasso_interactions(
        X, y, lam=lam, feature_names=feature_names, **lasso_kwargs
    )
    strong_enough = ranked['mean_abs_interaction'] >= tol
    return ranked.loc[strong_enough].reset_index(drop=True)


def generate_synthetic_data(
    n_samples: int = 200,
    n_features: int = 8,
    n_main_effects: int = 2,
    n_interactions: int = 1,
    noise_std: float = 0.5,
    random_state: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate synthetic data for testing hierarchical lasso.

    Features are i.i.d. standard normal. The target is
    ``y = X @ beta_true + 0.5 * x^T Theta_true x + noise_std * N(0, 1)``.

    Parameters
    ----------
    n_samples : int, default=200
        Number of samples to generate.
    n_features : int, default=8
        Number of features.
    n_main_effects : int, default=2
        Number of non-zero main effects (must be <= n_features).
    n_interactions : int, default=1
        Number of distinct non-zero interaction pairs
        (must be <= n_features * (n_features - 1) / 2).
    noise_std : float, default=0.5
        Standard deviation of noise.
    random_state : int or None, default=None
        Random seed.

    Returns
    -------
    X : np.ndarray
        Feature matrix of shape (n_samples, n_features).
    y : np.ndarray
        Target values of shape (n_samples,).
    beta_true : np.ndarray
        True main effect coefficients, shape (n_features,).
    Theta_true : np.ndarray
        True symmetric interaction matrix with zero diagonal, shape (n_features, n_features).

    Raises
    ------
    ValueError
        If more interactions are requested than there are distinct feature pairs.
    """
    max_pairs = n_features * (n_features - 1) // 2
    if n_interactions > max_pairs:
        raise ValueError(
            f"n_interactions ({n_interactions}) exceeds the number of distinct feature "
            f"pairs ({max_pairs}) for n_features={n_features}"
        )

    rng = np.random.default_rng(random_state)

    # Features
    X = rng.normal(size=(n_samples, n_features))

    # True main effects
    beta_true = np.zeros(n_features)
    main_indices = rng.choice(n_features, size=n_main_effects, replace=False)
    beta_true[main_indices] = rng.normal(scale=1.0, size=n_main_effects)

    # True interactions: draw until n_interactions *distinct* pairs are set.
    # A repeated pair would otherwise overwrite its coefficient and leave fewer
    # interactions than requested. The RNG stream is unchanged until the first
    # repeat occurs.
    Theta_true = np.zeros((n_features, n_features))
    chosen = set()
    while len(chosen) < n_interactions:
        i, j = sorted(rng.choice(n_features, size=2, replace=False))
        if (i, j) in chosen:
            continue
        chosen.add((i, j))
        coef = rng.normal(scale=1.0)
        Theta_true[i, j] = Theta_true[j, i] = coef

    # Target
    y = (X @ beta_true +
         0.5 * np.sum((X @ Theta_true) * X, axis=1) +
         noise_std * rng.normal(size=n_samples))

    return X, y, beta_true, Theta_true


if __name__ == "__main__":
    # Quick sanity check on synthetic data. cvxpy is imported unconditionally at
    # module load, so reaching this point means it is installed. The explicitly
    # requested OSQP solver may still be unavailable, in which case ``fit``
    # raises RuntimeError.
    print("Testing HierarchicalLasso with synthetic data...")

    X, y, beta_true, Theta_true = generate_synthetic_data(
        n_samples=200, n_features=8, random_state=42
    )

    model = HierarchicalLasso(
        lam=0.15,
        eps=1e-8,
        solver="OSQP",
        solver_opts={"max_iter": 20000},
        random_state=42
    )

    try:
        model.fit(X, y, verbose=False)
    except RuntimeError as e:
        # ``fit`` wraps every solver failure (including a missing solver) in RuntimeError.
        print(f"Error during model fitting: {e}")
        print("Check that the OSQP solver is installed, or construct the model with solver=None.")
    else:
        print("\nModel Summary:")
        print(model.summary())
        print(f"\nR² Score: {model.score(X, y):.4f}")
        print(f"MSE: {model.score_mse(X, y):.4f}")
